
feat: add PyTorch C++ backend for Add and Gemm #51

Merged
voltjia merged 21 commits into master from feat/torch-backend on Apr 16, 2026

Conversation


@voltjia voltjia commented Apr 13, 2026

Summary

  • Add slotted ActiveImplementationsImpl<Key, kDev, N> so add-on backends can register extra implementation indices without modifying existing device registries
  • Add WITH_TORCH CMake option that links against libtorch and compiles sources under src/torch/
  • Implement torch Add (index 1) and Gemm (index 2) using ATen's device-generic dispatch (at::add_out, at::addmm_out / at::baddbmm_out)
  • Extract TorchDeviceName<kDev> into src/torch/device_.h for reuse and add zero-copy ToAtenTensor<kDev>() conversion via at::from_blob()
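
The slot mechanism in the first bullet can be sketched in plain C++. This is a minimal illustration of the idea, not the project's actual code; the names `ActiveImplementationsImpl`, `Add`, and the device enum are assumptions modeled on the PR description.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for the project's device enumeration.
enum class Device { kCpu };

// Primary template: a slot that nobody specialized contributes no indices.
template <typename Key, Device kDev, std::size_t N = 0>
struct ActiveImplementationsImpl {
    static constexpr std::array<int, 0> indices{};
};

struct Add {};

// Slot 0 holds the base/device-native implementation index.
template <>
struct ActiveImplementationsImpl<Add, Device::kCpu, 0> {
    static constexpr std::array<int, 1> indices{0};
};

// Slot 1: an add-on backend (e.g. torch) registers index 1 in a separate
// translation unit, without touching the slot-0 specialization above.
template <>
struct ActiveImplementationsImpl<Add, Device::kCpu, 1> {
    static constexpr std::array<int, 1> indices{1};
};
```

Because each backend owns a distinct `N`, two backends never need to edit the same specialization, which is what lets `src/torch/` register implementations without modifying existing device registries.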

voltjia added 13 commits April 15, 2026 02:41
Move the `TorchDeviceName<kDev>` template specializations from
`pybind11_utils.h` into a standalone header so they can be reused
by torch operator implementations without pulling in pybind11.
Add a third template parameter `N` (slot index, default 0) to
`ActiveImplementationsImpl`. Slot 0 holds the base/device-native
indices, and higher slots let add-on backends register extra indices
without conflicting with existing specializations. `ActiveImplementations`
flattens slots 0-3 via `Flatten`.
Add `WITH_TORCH` option that finds PyTorch via pip, links against
libtorch, and compiles sources under `src/torch/`. Pass `--with-torch`
to `generate_wrappers.py` so it scans `src/torch/` for operator
specializations.
Add `ToAtenDtype()` and `ToAtenTensor<kDev>()` in `src/torch/tensor_.h`
for zero-copy conversion from `infini::ops::Tensor` to `at::Tensor`
via `at::from_blob()`.
Register torch `Add` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::add_out()` through ATen's
device-generic dispatch.
Register torch `Gemm` via `ActiveImplementationsImpl` slot 1 for all
devices. The implementation uses `at::addmm_out()` / `at::baddbmm_out()`
through ATen's device-generic dispatch.
Auto-detect PyTorch by attempting `import torch`. When found,
`WITH_TORCH` is enabled automatically.
`find_package(Torch)` pulls in Caffe2's cmake config, which calls
`enable_language(CUDA)` and breaks on platforms with non-standard
CUDA toolchains (e.g. Iluvatar). Query include and library paths
directly via `torch.utils.cpp_extension` instead.
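
The path query described above might look roughly like the following CMake fragment. This is a sketch; the variable names are illustrative, not the project's actual ones, and only `torch.utils.cpp_extension.include_paths()`/`library_paths()` are taken from the commit message.

```cmake
# Ask the installed Python torch package for its include and library
# directories directly, sidestepping `find_package(Torch)` and the
# `enable_language(CUDA)` side effect of Caffe2's cmake config.
execute_process(
    COMMAND "${Python_EXECUTABLE}" -c
        "from torch.utils.cpp_extension import include_paths, library_paths; print(';'.join(include_paths() + library_paths()))"
    OUTPUT_VARIABLE TORCH_PATHS
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
```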
…(CUDA)`

CMake 4.3+ requires `CMAKE_CUDA_ARCHITECTURES` to be set before
`enable_language(CUDA)` when using non-standard CUDA compilers like
Iluvatar's `clang++`. Without it, CMake fails to detect a default
architecture.
When `pybind11` is installed via pip but not in a standard CMake search
path, `find_package(pybind11 CONFIG)` fails. Query `python -m pybind11
--cmakedir` as a fallback to locate the package.
- Iluvatar: set `CMAKE_CUDA_ARCHITECTURES` to `OFF` instead of
  `ivcore20` (CMake 4.3 rejects non-integer architecture names; the
  architecture is already passed via `CMAKE_CUDA_FLAGS`).

- MetaX/Moore: split torch operator headers into declaration-only `.h`
  files and `.cc` implementation files with explicit instantiations.
  Compile the `.cc` files with the system `g++` instead of the vendor
  compiler (`mxcc`/`mcc`), which cannot parse vendor-forked torch
  headers in C++ extension mode.

- Cambricon: guard `UInt16`/`UInt32`/`UInt64` scalar types in
  `ToAtenDtype()` with a `TORCH_VERSION` check (these types require
  PyTorch 2.4+; Cambricon ships torch 2.1).

- Wrapper generator: scan only `.h` files to avoid including `.cc`
  explicit-instantiation files in the generated `ops.cc`.
- Iluvatar: move `-x ivcore` from CMAKE_CUDA_FLAGS to compile-only
  options so it doesn't get passed during linking (which caused
  clang++ to re-parse .o files as source code)
- MetaX: add `-DUSE_MACA=1` to g++ flags for torch source
  compilation (MetaX torch fork headers require this define)
- Cambricon: query `torch.compiled_with_cxx11_abi()` and set
  `_GLIBCXX_USE_CXX11_ABI` globally to match torch's ABI setting
  (fixes undefined reference to `c10::Device::Device(std::string)`)
Use backtick-fenced Markdown syntax for identifiers in comments
and error messages, and ensure comments are complete sentences.
@voltjia voltjia force-pushed the feat/torch-backend branch from 465ac5d to 9c04549 Compare April 15, 2026 03:02

voltjia commented Apr 15, 2026

results.log

Replace the manual `ActiveImplementationsImpl` slot system with
`std::is_base_of`-based compile-time detection. A real `Operator`
specialization inherits from `Key` (e.g., `Gemm`), while the primary
template inherits only from `OperatorBase` — SFINAE distinguishes the
two automatically, eliminating the need for `registry.h` files.
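
The detection trick described above can be sketched in standard C++ (names assumed from the comment, not copied from the code):

```cpp
#include <type_traits>

struct OperatorBase {};
struct Gemm {};
struct Relu {};  // hypothetical key with no real implementation

// Primary template: derives only from `OperatorBase`.
template <typename Key>
struct Operator : OperatorBase {};

// A real specialization also inherits from its key type.
template <>
struct Operator<Gemm> : OperatorBase, Gemm { /* implementation */ };

// `std::is_base_of` distinguishes real specializations from the primary
// template at compile time, with no explicit registry needed.
template <typename Key>
constexpr bool kHasImplementation = std::is_base_of<Key, Operator<Key>>::value;
```

A backend then only has to provide the specialization itself; its mere existence is the registration.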

voltjia commented Apr 16, 2026

results.log

Comment thread src/torch/tensor_.h Outdated
Comment thread tests/test_gemm.py
Comment thread src/operator.h Outdated
voltjia added 5 commits April 16, 2026 07:14
Replace the hand-unrolled `Flatten<..., 0>::type, ..., 3>::type>` with
`std::index_sequence<0..kMaxImplementations>` expansion. Increase
`kMaxImplementations` from 4 to 16.
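
The `index_sequence` expansion can be sketched as follows, with stand-in `Slot` types instead of the real per-backend specializations (all names here are assumptions):

```cpp
#include <array>
#include <cstddef>
#include <utility>

constexpr std::size_t kMaxImplementations = 16;

// Stand-ins: each slot contributes an array of implementation indices.
template <std::size_t N> struct Slot { static constexpr std::array<int, 0> indices{}; };
template <> struct Slot<0> { static constexpr std::array<int, 1> indices{0}; };
template <> struct Slot<1> { static constexpr std::array<int, 1> indices{1}; };

// Concatenate slots 0..kMaxImplementations-1 by expanding an index_sequence,
// instead of hand-unrolling `Flatten<...0>::type, ..., 3>::type>`.
template <std::size_t... Ns>
constexpr auto FlattenImpl(std::index_sequence<Ns...>) {
    std::array<int, (Slot<Ns>::indices.size() + ... + 0)> result{};
    std::size_t pos = 0;
    auto copy = [&](const auto& a) {
        for (int v : a) result[pos++] = v;
    };
    (copy(Slot<Ns>::indices), ...);  // left-to-right over all slots
    return result;
}

constexpr auto kActive = FlattenImpl(std::make_index_sequence<kMaxImplementations>{});
```

Empty slots contribute zero-length arrays, so raising `kMaxImplementations` to 16 costs nothing at runtime.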
…or_.h`

Use `constexpr int kTorchVersion` and `if constexpr` instead of `#if`
macros for PyTorch version checks. Extract unsigned dtype handling into
`detail::ToAtenUnsignedDataType`.
`c10::ScalarType::UInt16` is a non-dependent name resolved at template
definition time. Introduce `DependentScalarType<kVersion>::type` so
the enum member access becomes dependent and is properly discarded by
`if constexpr` on older PyTorch versions.
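
The dependent-name trick can be demonstrated without torch, using stand-in enum types (everything below is a hypothetical model of the problem, not the project's code):

```cpp
// A non-dependent name in a discarded `if constexpr` branch is still looked
// up at template definition time, so naming an enumerator that only exists
// in newer PyTorch breaks the build on older versions. Routing the access
// through a template parameter makes the name dependent, so the discarded
// branch is never checked.
constexpr int kTorchVersion = 201;  // stand-in for the detected version

enum class OldScalarType { Float };          // models torch 2.1
enum class NewScalarType { Float, UInt16 };  // models torch 2.4+

template <int kVersion>
struct DependentScalarType { using type = NewScalarType; };
template <>
struct DependentScalarType<201> { using type = OldScalarType; };

template <int kVersion = kTorchVersion>
constexpr int ToDtype(bool wantUnsigned16) {
    using ST = typename DependentScalarType<kVersion>::type;
    if constexpr (kVersion >= 204) {
        // `ST::UInt16` is dependent, so it is only resolved when this
        // branch is actually instantiated.
        if (wantUnsigned16) return static_cast<int>(ST::UInt16);
    }
    return static_cast<int>(ST::Float);
}
```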
ATen `addmm`/`baddbmm` do not support `float16`/`bfloat16` on CPU.
@voltjia voltjia force-pushed the feat/torch-backend branch from d6c1a64 to 24572c8 Compare April 16, 2026 08:02

voltjia commented Apr 16, 2026

results.log

@voltjia voltjia requested a review from Ziminli April 16, 2026 08:21

voltjia commented Apr 16, 2026

results.log

@voltjia voltjia merged commit 299e858 into master Apr 16, 2026
4 checks passed
@voltjia voltjia deleted the feat/torch-backend branch April 16, 2026 09:23
